{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "import torch.nn.init as init\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import numpy as np\n",
    "from six.moves import cPickle as pickle\n",
    "from six.moves import range\n",
    "\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train set (5679, 1, 64, 64) (5679, 6)\n"
     ]
    }
   ],
   "source": [
    "pickle_file = 'SVHN_1x64x64_valid.pickle'\n",
    "\n",
    "with open(pickle_file, 'rb') as f:\n",
    "    save = pickle.load(f)\n",
    "    train_dataset = save['valid_dataset']\n",
    "    train_labels = save['valid_labels']\n",
    "    del save  # hint to help gc free up memory\n",
    "    print('train set', train_dataset.shape, train_labels.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1545\n",
      "2988\n",
      "985\n"
     ]
    }
   ],
   "source": [
    "c3 = []\n",
    "c2 = []\n",
    "c1 = []\n",
    "\n",
    "for i in range(5679):\n",
    "    categ = train_labels[i][0]\n",
    "    if(categ == 3):\n",
    "        c3.append(i)\n",
    "    elif(categ == 2):\n",
    "        c2.append(i)\n",
    "    elif(categ == 1):\n",
    "        c1.append(i)\n",
    "\n",
    "print(len(c3))\n",
    "print(len(c2))\n",
    "print(len(c1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1545, 1, 64, 64)\n",
      "(1545, 6)\n"
     ]
    }
   ],
   "source": [
    "c3_data = train_dataset[c3]\n",
    "c3_target = train_labels[c3]\n",
    "print(c3_data.shape)\n",
    "print(c3_target.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(2988, 1, 64, 64)\n",
      "(2988, 6)\n"
     ]
    }
   ],
   "source": [
    "c2_data = train_dataset[c2]\n",
    "c2_target = train_labels[c2]\n",
    "print(c2_data.shape)\n",
    "print(c2_target.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(985, 1, 64, 64)\n",
      "(985, 6)\n"
     ]
    }
   ],
   "source": [
    "c1_data = train_dataset[c1]\n",
    "c1_target = train_labels[c1]\n",
    "print(c1_data.shape)\n",
    "print(c1_target.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5518, 1, 64, 64) (5518, 6)\n"
     ]
    }
   ],
   "source": [
    "numpy_data = np.concatenate((c1_data, c2_data, c3_data))\n",
    "numpy_target = np.concatenate((c1_target, c2_target, c3_target))\n",
    "print(numpy_data.shape, numpy_target.shape )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "del c1_data, c2_data, c3_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "del c1_target, c2_target, c3_target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "del train_dataset, train_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "del c3, c2, c1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "rng_state = np.random.get_state()\n",
    "X = np.random.permutation(numpy_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "np.random.set_state(rng_state)\n",
    "Y = np.random.permutation(numpy_target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "rng_state = np.random.get_state()\n",
    "X = np.random.permutation(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "np.random.set_state(rng_state)\n",
    "Y = np.random.permutation(Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Compressed pickle size: 90539623\n"
     ]
    }
   ],
   "source": [
    "pickle_file = 'SVHN_digits_3_2_1_64x64x1_valid.pickle'\n",
    "try:\n",
    "    f = open(pickle_file, 'wb')\n",
    "    save = {\n",
    "        'valid_data': X,\n",
    "        'valid_target': Y\n",
    "            }\n",
    "    pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)\n",
    "    f.close()\n",
    "except Exception as e:\n",
    "    print('Unable to save data to', pickle_file, ':', e)\n",
    "    raise\n",
    "statinfo = os.stat(pickle_file)\n",
    "print('Compressed pickle size:', statinfo.st_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 20, 3, padding=(1, 1))\n",
    "        self.conv2 = nn.Conv2d(20, 40, 3, padding=(1, 1))\n",
    "        self.conv3 = nn.Conv2d(40, 80, 3, padding=(1, 1))\n",
    "        self.conv4 = nn.Conv2d(80, 120, 3, padding=(1, 1))\n",
    "        self.conv5 = nn.Conv2d(120, 160, 3, padding=(1, 1))\n",
    "        self.conv6 = nn.Conv2d(160, 200, 3, padding=(1, 1))\n",
    "        self.conv7 = nn.Conv2d(200, 240, 3, padding=(1, 1))\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.FC = nn.Linear(960, 1080)\n",
    "        self.digitlength = nn.Linear(1080, 7)\n",
    "        self.digit1 = nn.Linear(1080, 10)\n",
    "        self.digit2 = nn.Linear(1080, 10)\n",
    "        self.digit3 = nn.Linear(1080, 10)\n",
    "        self.digit4 = nn.Linear(1080, 10)\n",
    "        self.digit5 = nn.Linear(1080, 10)\n",
    "        \n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Conv2d):\n",
    "                init.kaiming_normal(m.weight)\n",
    "                m.bias.data.zero_()\n",
    "            elif isinstance(m, nn.Linear):\n",
    "                init.kaiming_normal(m.weight)\n",
    "                m.bias.data.zero_()\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.conv1(x))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = F.relu(self.conv3(x))\n",
    "        x = self.pool(F.relu(self.conv4(x)))\n",
    "        x = self.pool(F.relu(self.conv5(x)))\n",
    "        x = self.pool(F.relu(self.conv6(x)))\n",
    "        x = self.pool(F.relu(self.conv7(x)))\n",
    "        x = x.view(-1, 960)\n",
    "        x = self.FC(x)\n",
    "        yl = self.digitlength(x)\n",
    "        y1 = self.digit1(x)\n",
    "        y2 = self.digit2(x)\n",
    "        y3 = self.digit3(x)\n",
    "        y4 = self.digit4(x)\n",
    "        y5 = self.digit5(x)\n",
    "        return [yl, y1, y2, y3, y4, y5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "net = Net()\n",
    "net.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "for param in net.parameters():\n",
    "    if(param.grad is not None):\n",
    "        print(param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "c3_data_tensor = torch.from_numpy(c3_data)\n",
    "c3_target_tensor = torch.from_numpy(c3_target).type(torch.LongTensor)\n",
    "print(c3_data_tensor.type(), c3_data_tensor.size())\n",
    "print(c3_target_tensor.type(), c3_target_tensor.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "c2_data_tensor = torch.from_numpy(c2_data)\n",
    "c2_target_tensor = torch.from_numpy(c2_target).type(torch.LongTensor)\n",
    "print(c2_data_tensor.type(), c2_data_tensor.size())\n",
    "print(c2_target_tensor.type(), c2_target_tensor.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "c1_data_tensor = torch.from_numpy(c1_data)\n",
    "c1_target_tensor = torch.from_numpy(c1_target).type(torch.LongTensor)\n",
    "print(c1_data_tensor.type(), c1_data_tensor.size())\n",
    "print(c1_target_tensor.type(), c1_target_tensor.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "data_tensor = torch.cat((c3_data_tensor, c2_data_tensor, c1_data_tensor), 0)\n",
    "target_tensor = torch.cat((c3_target_tensor, c2_target_tensor, c1_target_tensor), 0)\n",
    "print(data_tensor.type(), data_tensor.size())\n",
    "print(target_tensor.type(), target_tensor.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "objective = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(net.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "num_epochs = 20\n",
    "batch_size = 64\n",
    "num_train =  c3_data.shape[0]\n",
    "iter_per_epoch = num_train // batch_size\n",
    "print_every = 300\n",
    "print(iter_per_epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "epoch_losses = {i:[] for i in range(num_epochs)}\n",
    "loss_history = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "a = torch.randn(10, 5)\n",
    "print(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "idx = torch.randperm(5)\n",
    "print(idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "b = a[idx]\n",
    "print(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "print(id(a))\n",
    "print(id(b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "b = Variable(b).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "print(id(b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "del b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "print(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "for epoch in range(num_epochs):  # loop over the dataset multiple times\n",
    "    \n",
    "    print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n",
    "    print('-' * 110)\n",
    "\n",
    "    i = 0\n",
    "    rng_state = torch.get_rng_state()\n",
    "    new_idxs = torch.randperm(num_train)\n",
    "    \n",
    "    t1 = time.time()\n",
    "    for t in range(iter_per_epoch):\n",
    "        batch_idxs = new_idxs[i: i+batch_size]\n",
    "        X_batch = data_tensor[batch_idxs]\n",
    "        Y_batch = target_tensor[batch_idxs][:,0:4]\n",
    "        i += batch_size\n",
    "\n",
    "        X_batch = Variable(X_batch).cuda()\n",
    "        optimizer.zero_grad()\n",
    "        outputs = net(X_batch)\n",
    "        \n",
    "        if \n",
    "        \n",
    "        lossl = objective(outputs[0], Y_batch[:, 0])\n",
    "        loss1 = objective(outputs[1], Y_batch[:, 1])\n",
    "        loss2 = objective(outputs[2], Y_batch[:, 2])\n",
    "        loss3 = objective(outputs[3], Y_batch[:, 3])\n",
    "        final_loss = lossl + loss1 + loss2 + loss3\n",
    "        \n",
    "        final_loss.backward()\n",
    "        \n",
    "        optimizer.step()\n",
    "        \n",
    "        loss_history.append(final_loss.data[0])\n",
    "        epoch_losses[epoch].append(final_loss.data[0])\n",
    "        \n",
    "        if (t % print_every == 0):\n",
    "            print('Iteration : ', t+1, ' / ', iter_per_epoch)\n",
    "            print('loss : ', final_loss.data[0])\n",
    "            print('lossl : ', lossl.data[0], 'loss1 : ', loss1.data[0], 'loss2 : ', loss2.data[0], 'loss3 : ', loss3.data[0])\n",
    "        \n",
    "    t2 = time.time()\n",
    "    print(\"time taken : \", t2-t1)\n",
    "    print('-' * 110)\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "arrayidx = torch.randperm(5)\n",
    "print(arrayidx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "print(a[arrayidx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "f = open(\"c3_leg_ep_20.pkl\", \"bw\")\n",
    "torch.save(net.state_dict(), f)\n",
    "f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "plt.plot(loss_history)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
